{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "31bb147b",
   "metadata": {},
   "outputs": [],
   "source": [
    "%run -i \"../util/lang_utils.ipynb\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "b3fe070d",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "from spacy.cli.train import train\n",
    "from spacy.cli.evaluate import evaluate\n",
    "from spacy.cli.debug_data import debug_data\n",
    "from spacy.tokens import DocBin\n",
    "from sklearn.metrics import classification_report\n",
    "# Config generated at https://spacy.io/usage/training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "d4195a05",
   "metadata": {},
   "outputs": [],
   "source": [
    "def preprocess_data_entry(input_text, label, label_list):\n",
    "    doc = small_model(input_text)\n",
    "    cats = [0] * len(label_list)\n",
    "    cats[label] = 1\n",
    "    final_cats = {}\n",
    "    for i, label in enumerate(label_list):\n",
    "        final_cats[label] = cats[i]\n",
    "    doc.cats = final_cats\n",
    "    return doc"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "029ca3dc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load and prepare data\n",
    "train_db = DocBin()\n",
    "test_db = DocBin()\n",
    "label_list = [\"tech\", \"business\", \"sport\", \"entertainment\", \"politics\"]\n",
    "train_df = pd.read_json(\"../data/bbc_train.json\")\n",
    "test_df = pd.read_json(\"../data/bbc_test.json\")\n",
    "train_df.sample(frac=1)\n",
    "for idx, row in train_df.iterrows():\n",
    "    text = row[\"text\"]\n",
    "    label = row[\"label\"]\n",
    "    doc = preprocess_data_entry(text, label, label_list)\n",
    "    train_db.add(doc)\n",
    "for idx, row in test_df.iterrows():\n",
    "    text = row[\"text\"]\n",
    "    label = row[\"label\"]\n",
    "    doc = preprocess_data_entry(text, label, label_list)\n",
    "    test_db.add(doc)\n",
    "train_db.to_disk('../data/bbc_train.spacy')\n",
    "test_db.to_disk('../data/bbc_test.spacy')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "id": "a6614a01",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[38;5;4mℹ Saving to output directory: ../models/spacy_textcat_bbc\u001b[0m\n",
      "\u001b[38;5;4mℹ Using CPU\u001b[0m\n",
      "\u001b[1m\n",
      "=========================== Initializing pipeline ===========================\u001b[0m\n",
      "\u001b[38;5;2m✔ Initialized pipeline\u001b[0m\n",
      "\u001b[1m\n",
      "============================= Training pipeline =============================\u001b[0m\n",
      "\u001b[38;5;4mℹ Pipeline: ['tok2vec', 'textcat']\u001b[0m\n",
      "\u001b[38;5;4mℹ Initial learn rate: 0.001\u001b[0m\n",
      "E    #       LOSS TOK2VEC  LOSS TEXTCAT  CATS_SCORE  SCORE \n",
      "---  ------  ------------  ------------  ----------  ------\n",
      "  0       0          0.00          0.16        8.48    0.08\n",
      "  0     200         20.77         37.26       35.58    0.36\n",
      "  0     400         98.56         35.96       26.90    0.27\n",
      "  0     600         49.83         37.31       36.60    0.37\n",
      "  0     800         96.46         27.11       38.64    0.39\n",
      "  0    1000        102.35         22.53       43.11    0.43\n",
      "  1    1200        101.42         23.68       61.93    0.62\n",
      "  1    1400         50.70         19.16       55.75    0.56\n",
      "  1    1600        224.28         15.09       46.57    0.47\n",
      "  1    1800        354.78         18.86       84.20    0.84\n",
      "  1    2000        131.65         16.03       64.69    0.65\n",
      "  2    2200        171.85         13.44       78.76    0.79\n",
      "  2    2400        290.96         13.70       78.22    0.78\n",
      "  2    2600         93.77         15.27       67.76    0.68\n",
      "  2    2800        467.41         12.75       72.55    0.73\n",
      "  2    3000       3693.98         15.34       72.21    0.72\n",
      "  3    3200       1208.15         10.53       84.84    0.85\n",
      "  3    3400       1796.60          9.13       87.51    0.88\n",
      "  3    3600        777.35          7.71       93.34    0.93\n",
      "  3    3800       1498.47          7.51       88.65    0.89\n",
      "  3    4000       3671.99         10.10       82.50    0.82\n",
      "  4    4200       7338.34          8.25       79.49    0.79\n",
      "  4    4400       2924.32          9.85       86.56    0.87\n",
      "  4    4600       2889.25          7.67       86.86    0.87\n",
      "  4    4800       7571.47          9.64       80.25    0.80\n",
      "  4    5000      16164.99         10.58       87.71    0.88\n",
      "  5    5200       8604.43          8.20       84.98    0.85\n",
      "\u001b[38;5;2m✔ Saved pipeline to output directory\u001b[0m\n",
      "../models/spacy_textcat_bbc/model-last\n"
     ]
    }
   ],
   "source": [
    "train(\"../data/spacy_config.cfg\", output_path=\"../models/spacy_textcat_bbc\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "8a4b5ddd",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "lib dems  new election pr chief the lib dems have appointed a senior figure from bt to be the party s new communications chief for their next general election effort.  sandy walkington will now work with senior figures such as matthew taylor on completing the party manifesto. party chief executive lord rennard said the appointment was a  significant strengthening of the lib dem team . mr walkington said he wanted the party to be ready for any  mischief  rivals or the media tried to throw at it.   my role will be to ensure this new public profile is effectively communicated at all levels   he said.  i also know the party will be put under scrutiny in the media and from the other parties as never before - and we will need to show ourselves ready and prepared to counter the mischief and misrepresentation that all too often comes from the party s opponents.  the party is already demonstrating on every issue that it is the effective opposition.  mr walkington s new job title is director of general election communications.\n",
      "8    politics\n",
      "Name: label_text, dtype: object\n",
      "Predicted probabilities:  {'tech': 3.531841841208916e-08, 'business': 0.000641813559923321, 'sport': 0.00033847044687718153, 'entertainment': 0.00016174423217307776, 'politics': 0.9988579750061035}\n"
     ]
    }
   ],
   "source": [
    "# Use the trained model\n",
    "nlp = spacy.load(\"../models/spacy_textcat_bbc/model-last\")\n",
    "input_text = test_df.iloc[1, test_df.columns.get_loc('text')]\n",
    "print(input_text)\n",
    "print(test_df[\"label_text\"].iloc[[1]])\n",
    "doc = nlp(input_text)\n",
    "print(\"Predicted probabilities: \", doc.cats)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "6ee412f8-4ff3-4603-a8da-4cbf04b9b613",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluate the model on test data\n",
    "def get_prediction(input_text, nlp_model, target_names):\n",
    "    doc = nlp_model(input_text)\n",
    "    category = max(doc.cats, key = doc.cats.get)\n",
    "    return target_names.index(category)\n",
    "test_df[\"prediction\"] = test_df[\"text\"].apply(lambda x: get_prediction(x, nlp, label_list))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "b9d1e228-eee2-4b6e-8e1b-63dcb3d7dffa",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "               precision    recall  f1-score   support\n",
      "\n",
      "         tech       0.82      0.94      0.87        80\n",
      "     business       0.94      0.83      0.89       102\n",
      "        sport       0.89      0.89      0.89       102\n",
      "entertainment       0.94      0.87      0.91        77\n",
      "     politics       0.78      0.83      0.80        84\n",
      "\n",
      "     accuracy                           0.87       445\n",
      "    macro avg       0.87      0.87      0.87       445\n",
      " weighted avg       0.88      0.87      0.87       445\n",
      "\n"
     ]
    }
   ],
   "source": [
    "print(classification_report(test_df[\"label\"], test_df[\"prediction\"], target_names=label_list))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "cdeb36d0-fd80-45ff-83cf-45f05e5edaa6",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'token_acc': 1.0,\n",
       " 'token_p': 1.0,\n",
       " 'token_r': 1.0,\n",
       " 'token_f': 1.0,\n",
       " 'cats_score': 0.8719339318444819,\n",
       " 'cats_score_desc': 'macro F',\n",
       " 'cats_micro_p': 0.8719101123595505,\n",
       " 'cats_micro_r': 0.8719101123595505,\n",
       " 'cats_micro_f': 0.8719101123595505,\n",
       " 'cats_macro_p': 0.8746516896205309,\n",
       " 'cats_macro_r': 0.8732906799083269,\n",
       " 'cats_macro_f': 0.8719339318444819,\n",
       " 'cats_macro_auc': 0.9800144873453936,\n",
       " 'cats_f_per_type': {'tech': {'p': 0.8152173913043478,\n",
       "   'r': 0.9375,\n",
       "   'f': 0.872093023255814},\n",
       "  'business': {'p': 0.9444444444444444,\n",
       "   'r': 0.8333333333333334,\n",
       "   'f': 0.8854166666666667},\n",
       "  'sport': {'p': 0.8921568627450981,\n",
       "   'r': 0.8921568627450981,\n",
       "   'f': 0.8921568627450981},\n",
       "  'entertainment': {'p': 0.9436619718309859,\n",
       "   'r': 0.8701298701298701,\n",
       "   'f': 0.9054054054054054},\n",
       "  'politics': {'p': 0.7777777777777778,\n",
       "   'r': 0.8333333333333334,\n",
       "   'f': 0.8045977011494253}},\n",
       " 'cats_auc_per_type': {'tech': 0.9842808219178081,\n",
       "  'business': 0.9824501229063054,\n",
       "  'sport': 0.9933544846510032,\n",
       "  'entertainment': 0.9834839073969509,\n",
       "  'politics': 0.9565030998549005},\n",
       " 'speed': 6894.989948433934}"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "evaluate('../models/spacy_textcat_bbc/model-last', '../data/bbc_test.spacy')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "79b2ad1f-ae68-4390-8cdb-16ff9c12f98e",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
